import os
import time

import gymnasium as gym
import numpy as np
import torch
import torch.nn.functional as F

from algorithms.a2c import A2C_RL2
from algorithms.online_storage import OnlineStorageRL2
from environments.parallel_envs import make_vec_envs
from models.policy_net import NoisyActorCriticRNN, ActorCriticRNN, SharedActorCriticRNN, nmActorCriticRNN
from utils import helpers as utl
from utils.tb_logger import TBLogger

from utils.evaluation import get_empirical_returns


# Every tensor and network created in this module is placed on this device:
# the CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class MetaLearner:
    """
    Meta-Learner class with the main training loop for RL2
    """

    def __init__(self, args):
        """
        Create the environments, derive model dimensions, and build the policy.

        Args:
            args: experiment configuration object. Besides being read, it is
                written to: ``state_dim``, ``input_state_dim_for_policy``,
                ``action_space_type``, ``action_dim`` and ``reward_dim`` are
                stored on it for later use by the networks and the storage.

        Raises:
            AssertionError: if the single-env observation space is not a
                ``gym.spaces.Dict``.
            NotImplementedError: if the action space is not Discrete, or if
                ``args.exp_label`` is neither 'rl2' nor 'noisy_rl2'.
        """
        self.args = args
        # no update has run yet; log() gates on iter_idx + 1
        self.iter_idx = -1
        utl.seed(self.args.seed, self.args.deterministic_execution)

        # -- initialize environments --
        self.envs = make_vec_envs(
            env_name=self.args.env_name,
            max_episode_steps=self.args.max_episode_steps,
            num_processes=self.args.num_processes
        )
        print(f'initialize envs done: {self.envs}')

        # -- get dimensions: state, action, reward --
        # state: which observation entries count depends on the env family,
        # i.e. the part of env_name before the first '-'
        obs_space = self.envs.single_observation_space
        assert isinstance(obs_space, gym.spaces.Dict)
        env_family = self.args.env_name.split('-')[0]
        dynamic_foraging_envs = ('CoupledBlockDF', 'UncoupledBlockDF', 'RandomWalkDF')
        timed_bandit_envs = ('TimedBlockBandit2ArmCoupledEasy',
                             'TimedBlockBandit2ArmCoupledMultipleProb')
        if env_family in dynamic_foraging_envs:
            state_dim = obs_space['trial'].shape[0]
        elif env_family in timed_bandit_envs:
            # timestep features plus (n - 1) entries for the discrete go cue
            state_dim = obs_space['timestep'].shape[0] + obs_space['go_cue'].n - 1
        else:
            # trial-based bandit environments
            state_dim = obs_space['timestep'].shape[0]
        self.args.state_dim = state_dim

        # the policy input has one feature fewer when time is not fed to it
        self.args.input_state_dim_for_policy = (
            state_dim if self.args.time_as_state else state_dim - 1
        )

        # action
        action_space = self.envs.single_action_space
        if not isinstance(action_space, gym.spaces.discrete.Discrete):
            raise NotImplementedError(f'action space currently only supports Discrete')
        self.args.action_space_type = 'Discrete'
        self.args.action_dim = action_space.n

        # reward
        self.args.reward_dim = 1

        # -- initialize tensorboard logger and save args --
        self.logger = TBLogger(self.args, self.args.exp_label)

        # -- model initialization --
        # both supported experiment types use the same rollout storage and
        # differ only in how the policy is built
        if self.args.exp_label == 'rl2':
            build_policy = self.initialize_policy_rl2
        elif self.args.exp_label == 'noisy_rl2':
            build_policy = self.initialize_policy_noisy_rl2
        else:
            raise NotImplementedError
        self.policy = build_policy()
        self.policy_storage = OnlineStorageRL2(self.args)


    def initialize_policy_rl2(self):
        """
        Build the recurrent actor-critic network and its trainer for 'rl2'.

        The network class is chosen from the configuration:
        ``nmActorCriticRNN`` when ``args.NMd`` is set, otherwise
        ``SharedActorCriticRNN`` when ``args.shared_rnn`` is set, otherwise
        ``ActorCriticRNN``. The network is moved to ``device`` before being
        handed to the trainer.

        Returns:
            The ``A2C_RL2`` trainer wrapping the constructed network.

        Raises:
            NotImplementedError: if ``args.policy_algorithm`` is not 'a2c'.
        """
        cfg = self.args

        # constructor arguments common to all three network classes
        network_kwargs = dict(
            args=cfg,
            layers_before_rnn=cfg.layers_before_rnn,
            rnn_hidden_dim=cfg.rnn_hidden_dim,
            layers_after_rnn=cfg.layers_after_rnn,
            activation_function=cfg.policy_net_activation_function,
            rnn_cell_type=cfg.rnn_cell_type,
            state_dim=cfg.input_state_dim_for_policy,
            state_embed_dim=cfg.state_embed_dim,
            action_dim=cfg.action_dim,
            action_embed_dim=cfg.action_embed_dim,
            action_space_type=cfg.action_space_type,
            reward_dim=cfg.reward_dim,
            reward_embed_dim=cfg.reward_embed_dim,
        )
        init_method = cfg.policy_net_initialization_method

        # initialize policy network
        if cfg.NMd:
            # The original note here asks to make sure rnn_cell_type is the
            # 'nm' variant for this network.
            # NOTE(review): this constructor is called with the keyword
            # "initialisation_method" while the other two networks take
            # "initialization_method" -- confirm against nmActorCriticRNN.
            network = nmActorCriticRNN(
                initialisation_method=init_method,
                N_nm=cfg.nNM,
                **network_kwargs
            )
        elif cfg.shared_rnn:
            network = SharedActorCriticRNN(
                initialization_method=init_method,
                **network_kwargs
            )
        else:
            network = ActorCriticRNN(
                initialization_method=init_method,
                **network_kwargs
            )
        network = network.to(device)

        # initialize policy trainer
        if cfg.policy_algorithm != 'a2c':
            raise NotImplementedError

        return A2C_RL2(
            args=cfg,
            actor_critic=network,
            critic_loss_coeff=cfg.policy_critic_loss_coeff,
            entropy_loss_coeff=cfg.policy_entropy_loss_coeff,
            activity_l2_loss_coeff=cfg.policy_activity_l2_loss_coeff,
            policy_optimizer=cfg.policy_optimizer,
            policy_eps=cfg.policy_eps,
            policy_lr=cfg.policy_lr,
            policy_anneal_lr=cfg.policy_anneal_lr,
            train_steps=cfg.num_updates,
        )
            
    
    def initialize_policy_noisy_rl2(self):
        """
        Build the noisy recurrent actor-critic and its trainer for 'noisy_rl2'.

        Only the non-shared configuration is supported: the network is a
        ``NoisyActorCriticRNN`` configured with ``args.hidden_noise_std`` and
        moved to ``device``.

        Returns:
            The ``A2C_RL2`` trainer wrapping the constructed network.

        Raises:
            ValueError: if ``args.shared_rnn`` is set.
            NotImplementedError: if ``args.policy_algorithm`` is not 'a2c'.
        """
        cfg = self.args

        # TODO: support args.shared_rnn here (a SharedActorCriticRNN-based
        # variant); until then a shared RNN is rejected.
        if cfg.shared_rnn:
            raise ValueError(f'invalid args.shared_rnn: {cfg.shared_rnn}')

        # initialize policy network
        network = NoisyActorCriticRNN(
            args=cfg,
            layers_before_rnn=cfg.layers_before_rnn,
            rnn_hidden_dim=cfg.rnn_hidden_dim,
            layers_after_rnn=cfg.layers_after_rnn,
            activation_function=cfg.policy_net_activation_function,
            rnn_cell_type=cfg.rnn_cell_type,
            initialization_method=cfg.policy_net_initialization_method,
            state_dim=cfg.input_state_dim_for_policy,
            state_embed_dim=cfg.state_embed_dim,
            action_dim=cfg.action_dim,
            action_embed_dim=cfg.action_embed_dim,
            action_space_type=cfg.action_space_type,
            reward_dim=cfg.reward_dim,
            reward_embed_dim=cfg.reward_embed_dim,
            hidden_noise_std=cfg.hidden_noise_std
        ).to(device)

        # initialize policy trainer
        if cfg.policy_algorithm != 'a2c':
            raise NotImplementedError

        trainer_kwargs = dict(
            args=cfg,
            actor_critic=network,
            critic_loss_coeff=cfg.policy_critic_loss_coeff,
            entropy_loss_coeff=cfg.policy_entropy_loss_coeff,
            activity_l2_loss_coeff=cfg.policy_activity_l2_loss_coeff,
            policy_optimizer=cfg.policy_optimizer,
            policy_eps=cfg.policy_eps,
            policy_lr=cfg.policy_lr,
            policy_anneal_lr=cfg.policy_anneal_lr,
            train_steps=cfg.num_updates,
        )
        return A2C_RL2(**trainer_kwargs)


    def train(self):
        """
        Run the main meta-training loop.

        ``self.log`` is called once before the first update. Then, for each of
        the ``args.num_updates`` updates: the vectorized environments are
        reset, the current policy is rolled out for
        ``args.policy_num_steps_per_update`` steps while filling
        ``self.policy_storage``, returns are computed, one parameter update is
        applied, statistics are recorded, ``self.log`` is called (evaluation
        and checkpointing), and the storage is cleared.

        A single recurrent state is carried when ``args.NMd`` or
        ``args.shared_rnn`` is set; it is then stored as both the actor and
        the critic hidden state. Otherwise separate actor and critic hidden
        states are carried.

        Returns:
            tuple: ``(train_stats, evaluation_stats)``. ``train_stats`` maps
            'episode_returns', 'actor_losses', 'critic_losses',
            'policy_entropies' and 'activity_l2_loss' to lists holding one
            numpy value per update. ``evaluation_stats`` maps
            'eval_epoch_ids', 'empirical_return_avgs' and
            'empirical_return_stds' to the lists that ``self.log`` appends to.

        Raises:
            ValueError: if ``args.exp_label`` is neither 'rl2' nor 'noisy_rl2'.
        """
        # per-update training statistics
        train_stats = {
            'episode_returns': [],
            'actor_losses': [],
            'critic_losses': [],
            'policy_entropies': [],
            'activity_l2_loss': []
        }
        # evaluation: pass to log function
        evaluation_stats = {
            'eval_epoch_ids': [],
            'empirical_return_avgs': [],
            'empirical_return_stds': []
        }

        num_envs = self.args.num_processes

        def states_to_tensors(states_dict):
            """Turn env observation dicts into (1, num_envs, dim) tensors.

            Returns the full state (always including time) and the state that
            is fed to the policy (time included only if args.time_as_state).
            """
            full = utl.get_states_from_state_dicts(
                states_dict, self.args.env_name, True)
            for_policy = utl.get_states_from_state_dicts(
                states_dict, self.args.env_name, self.args.time_as_state)
            full = torch.from_numpy(full).float().reshape(
                (1, num_envs, self.args.state_dim)).to(device)
            for_policy = torch.from_numpy(for_policy).float().reshape(
                (1, num_envs, self.args.input_state_dim_for_policy)).to(device)
            return full, for_policy

        # log once before training
        with torch.no_grad():
            self.log(evaluation_stats)

        # training starts
        for self.iter_idx in range(int(self.args.num_updates)):
            print(f'training epoch: {self.iter_idx}')

            # -- COLLECT DATA -- #
            # reset all envs
            # NOTE(review): the same seed is passed on every update, which
            # re-seeds the environments identically each time -- confirm this
            # is intended rather than seeding only once.
            curr_states_dict, infos = self.envs.reset(seed=self.args.seed, options={})
            curr_states, curr_states_for_policy = states_to_tensors(curr_states_dict)

            prev_actions = torch.zeros(1, num_envs, self.args.action_dim).to(device)
            prev_rewards = torch.zeros(1, num_envs, self.args.reward_dim).to(device)

            if self.args.exp_label not in ('rl2', 'noisy_rl2'):
                raise ValueError(f'incompatible model type: {self.args.exp_label}')

            # One predicate decides the hidden-state layout for the whole
            # rollout. Previously the post-step update and the storage insert
            # ignored args.NMd, so NMd without shared_rnn read hidden-state
            # variables that were never assigned.
            single_rnn = bool(self.args.NMd or self.args.shared_rnn)

            # initialize rnn hidden states
            if single_rnn:
                hidden_dim = self.args.rnn_hidden_dim
                if self.args.NMd:
                    # the NMd network carries nNM extra hidden units
                    hidden_dim = hidden_dim + self.args.nNM
                # with a single RNN, actor and critic share one hidden state
                actor_prev_hidden_states = torch.zeros(
                    1, num_envs, hidden_dim).to(device)
                critic_prev_hidden_states = actor_prev_hidden_states
            else:
                actor_prev_hidden_states = torch.zeros(
                    1, num_envs, self.args.rnn_hidden_dim).to(device)
                critic_prev_hidden_states = torch.zeros(
                    1, num_envs, self.args.rnn_hidden_dim).to(device)

            # insert initial data to policy_storage
            self.policy_storage.insert_initial(
                states=curr_states.squeeze(0),
                states_for_policy=curr_states_for_policy.squeeze(0),
                actions=prev_actions.squeeze(0),
                rewards=prev_rewards.squeeze(0),
                actor_hidden_states=actor_prev_hidden_states.squeeze(0),
                critic_hidden_states=critic_prev_hidden_states.squeeze(0)
            )

            # rollout current policy for n steps in parallel environments
            for step in range(self.args.policy_num_steps_per_update):
                # sample actions from policy: act based on current policy
                with torch.no_grad():
                    if single_rnn:
                        actions_categorical, action_log_probs, entropy, state_values, \
                            rnn_hidden_states = \
                                self.policy.actor_critic.act(
                                    curr_states=curr_states_for_policy,
                                    prev_actions=prev_actions,
                                    prev_rewards=prev_rewards,
                                    rnn_prev_hidden_states=actor_prev_hidden_states,
                                    return_prior=False,
                                    deterministic=self.args.deterministic_policy)
                        actor_hidden_states = rnn_hidden_states
                        critic_hidden_states = rnn_hidden_states
                    else:
                        actions_categorical, action_log_probs, entropy, state_values, \
                            actor_hidden_states, critic_hidden_states = \
                                self.policy.actor_critic.act(
                                    curr_states=curr_states_for_policy,
                                    prev_actions=prev_actions,
                                    prev_rewards=prev_rewards,
                                    actor_prev_hidden_states=actor_prev_hidden_states,
                                    critic_prev_hidden_states=critic_prev_hidden_states,
                                    return_prior=False,
                                    deterministic=self.args.deterministic_policy)

                # perform the action A_{t} in the environment to get S_{t+1} and R_{t+1}
                next_states_dict, rewards, terminated, truncated, infos = self.envs.step(
                    actions_categorical.squeeze(0).cpu().numpy())

                # update inputs for next step; all tensors are (step, batch, feature)
                curr_states, curr_states_for_policy = states_to_tensors(next_states_dict)
                prev_rewards = torch.from_numpy(rewards).float()\
                    .reshape(1, num_envs, self.args.reward_dim).to(device)
                # masks for return calculation: for each env 1 if the episode
                # is ongoing and 0 if it is terminated (not by truncation!)
                masks_ongoing = torch.tensor([not term for term in terminated]).float()\
                    .reshape(1, num_envs, 1).to(device)
                # note: the policy network takes one-hot coded actions
                prev_actions = F.one_hot(actions_categorical, num_classes=self.args.action_dim)\
                    .float().reshape((1, num_envs, self.args.action_dim)).to(device)
                actor_prev_hidden_states = actor_hidden_states.to(device)
                critic_prev_hidden_states = critic_hidden_states.to(device)

                # insert experience to policy storage: the post-step state,
                # action, reward and hidden states go in together with the
                # log-prob and value computed when the action was chosen.
                # With a single RNN, the same hidden state is stored for both
                # the actor and the critic.
                self.policy_storage.insert(
                    states=curr_states.squeeze(0),
                    states_for_policy=curr_states_for_policy.squeeze(0),
                    actions=prev_actions.squeeze(0),
                    action_log_probs=action_log_probs.reshape(num_envs, 1),
                    rewards=prev_rewards.squeeze(0),
                    actor_hidden_states=actor_prev_hidden_states.squeeze(0),
                    critic_hidden_states=critic_prev_hidden_states.squeeze(0),
                    state_values=state_values.squeeze(0),
                    masks_ongoing=masks_ongoing.squeeze(0)
                )

            # -- UPDATE POLICY --
            # compute return
            self.policy_storage.compute_returns(
                self.args.policy_gamma,
                self.args.policy_use_gae,
                self.args.policy_lambda
            )

            # compute loss
            loss, actor_loss, critic_loss, policy_entropy, activity_l2_loss = \
                self.policy.get_losses(self.policy_storage)
            print(f'policy_loss: {loss}')
            print(f'actor_loss: {actor_loss}')
            print(f'critic_loss: {critic_loss}')
            print(f'policy_entropy: {policy_entropy}')
            print(f'activity_l2_loss: {activity_l2_loss}')

            # update parameters
            self.policy.update_parameters(loss)

            # -- LOG --
            with torch.no_grad():
                # Per-env return = sum of rewards over the step axis (entry 0
                # is the initial placeholder), then averaged over envs. No
                # argument-free squeeze() here: it would also drop the step
                # axis when there is exactly one step, turning the mean over
                # envs into a sum.
                episode_return = self.policy_storage.rewards[1:, :, :].sum(dim=0).mean()
                print(f'episode_return: {episode_return}')
                train_stats['episode_returns'].append(episode_return.detach().cpu().numpy())
                train_stats['actor_losses'].append(actor_loss.detach().cpu().numpy())
                train_stats['critic_losses'].append(critic_loss.detach().cpu().numpy())
                train_stats['policy_entropies'].append(policy_entropy.detach().cpu().numpy())
                train_stats['activity_l2_loss'].append(activity_l2_loss.detach().cpu().numpy())

            # evaluation
            with torch.no_grad():
                self.log(evaluation_stats)

            # clear up running storage after update
            self.policy_storage.after_update()

        return train_stats, evaluation_stats
        

    def log(self, evaluation_stats):
        # --- evaluate policy ----
        # if (self.iter_idx + 1) % self.args.eval_interval == 0:
        if (self.iter_idx+1) in self.args.eval_ids:
            print(f'EVALUATION: epoch {self.iter_idx}')
            evaluation_stats['eval_epoch_ids'].append(self.iter_idx)

            if self.args.exp_label in ['rl2', 'noisy_rl2']:
                encoder = None
                policy_network = self.policy.actor_critic
            else:
                raise ValueError(f'incompatible model type: {self.args.exp_label}')
                
            # for all environments: get empirical return
            num_eval_envs = self.args.num_eval_envs
            empirical_return_avg, empirical_return_std = get_empirical_returns(
                env_name=self.args.env_name,
                args=self.args,
                encoder=encoder,  # None if rl2
                policy_network=policy_network,
                num_envs=num_eval_envs
            )
            evaluation_stats['empirical_return_avgs'].append(empirical_return_avg)
            evaluation_stats['empirical_return_stds'].append(empirical_return_std)

        # --- save models ---
        if (self.iter_idx + 1) % self.args.save_interval == 0:
            print(f'SAVING MODEL: epoch {self.iter_idx}')
            save_path = self.logger.full_output_folder
            # save_path = os.path.join(self.logger.full_output_folder, 'models')
            # if not os.path.exists(save_path):
            #     os.mkdir(save_path)
            print(f'save_path: {save_path}')

            idx_labels = ['']
            if self.args.save_intermediate_models:
                idx_labels.append(int(self.iter_idx))

            for idx_label in idx_labels:
                # save model
                actor_critic_path = os.path.join(save_path, f'actor_critic_weights{idx_label}.h5')
                torch.save(self.policy.actor_critic.state_dict(), actor_critic_path)
                